Convolutional neural network

Tutorial on how to train a convolutional neural network to predict protein subcellular localization.


In [1]:
# Import all the necessary modules
import os
# Set the Theano flags before Theano is imported: run on CPU in float32
os.environ["THEANO_FLAGS"] = "mode=FAST_RUN,optimizer=None,device=cpu,floatX=float32"
import sys
sys.path.insert(0,'..')
import numpy as np
import theano
import theano.tensor as T
import lasagne
from confusionmatrix import ConfusionMatrix
from utils import iterate_minibatches
import matplotlib.pyplot as plt
import time
import itertools
%matplotlib inline

Building the network

The first thing we have to do is define the network architecture. Here we are going to use an input layer, two parallel convolutional layers with different filter sizes whose outputs are concatenated, a second convolutional layer followed by max pooling, a dense layer and a softmax output layer. These are the steps that we are going to follow:

1.- Specify the hyperparameters of the network:


In [2]:
batch_size = 128   # number of sequences per minibatch
seq_len = 400      # fixed (padded) length of the protein sequences
n_feat = 20        # features per sequence position (20 amino acids)
n_hid = 30         # units in the dense hidden layer
n_class = 10       # number of subcellular localization classes
lr = 0.0025        # learning rate
n_filt = 10        # number of convolutional filters per layer
drop_prob = 0.5    # dropout probability

2.- Define the input variables to our network:


In [3]:
# We use ftensor3 because the protein data is a 3D tensor of float32 values
input_var = T.ftensor3('inputs')
# ivector because the labels are a one-dimensional vector of integers
target_var = T.ivector('targets')

# Dummy data to check the layer output shapes while building the network
X = np.random.randint(0, 10, size=(batch_size, seq_len, n_feat)).astype('float32')
Xmask = np.ones((batch_size, seq_len)).astype('float32')
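
For reference, each real protein is encoded as a (seq_len, 20) matrix with one row per sequence position and one column per amino acid. A minimal one-hot encoding sketch is shown below; the amino-acid ordering is an illustrative assumption, not necessarily the one used to create the dataset, which was encoded before being saved.

# Hypothetical one-hot encoder; the alphabet ordering here is an assumption
AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'
aa_to_idx = {aa: i for i, aa in enumerate(AMINO_ACIDS)}

def one_hot_encode(seq, max_len=seq_len):
    # (max_len, 20) matrix with a 1 at each (position, amino acid) pair
    mat = np.zeros((max_len, len(AMINO_ACIDS)), dtype='float32')
    for pos, aa in enumerate(seq[:max_len]):
        mat[pos, aa_to_idx[aa]] = 1.0
    return mat

print(one_hot_encode('MKT').shape)  # (400, 20)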

3.- Define the layers of the network:


In [4]:
# Input layer, holds the shape of the data
l_in = lasagne.layers.InputLayer(shape=(batch_size, seq_len, n_feat), input_var=input_var, name='Input')
print('Input layer: {}'.format(
    lasagne.layers.get_output(l_in, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Dimshuffle from (batch, sequence, features) to (batch, features, sequence), the layout Conv1DLayer expects
l_shu = lasagne.layers.DimshuffleLayer(l_in, (0,2,1))

print('DimshuffleLayer layer: {}'.format(
    lasagne.layers.get_output(l_shu, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Two parallel convolutional layers with different filter sizes
l_conv_a = lasagne.layers.Conv1DLayer(l_shu, num_filters=n_filt, pad='same', stride=1, 
                                      filter_size=3, nonlinearity=lasagne.nonlinearities.rectify)
print('Convolutional layer size 3: {}'.format(
    lasagne.layers.get_output(l_conv_a, inputs={l_in: input_var}).eval({input_var: X}).shape))

l_conv_b = lasagne.layers.Conv1DLayer(l_shu, num_filters=n_filt, pad='same', stride=1, 
                                      filter_size=5, nonlinearity=lasagne.nonlinearities.rectify)
print('Convolutional layer size 5: {}'.format(
    lasagne.layers.get_output(l_conv_b, inputs={l_in: input_var}).eval({input_var: X}).shape))

# The outputs of the two convolutional layers are concatenated along the filter axis
l_conc = lasagne.layers.ConcatLayer([l_conv_a, l_conv_b], axis=1)
print('Concatenated convolutional layers: {}'.format(
    lasagne.layers.get_output(l_conc, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Second CNN layer
l_conv_final = lasagne.layers.Conv1DLayer(l_conc, num_filters=n_filt*2, pad='same', 
                                          stride=1, filter_size=3, 
                                          nonlinearity=lasagne.nonlinearities.rectify)
print('Final convolutional layer: {}'.format(
    lasagne.layers.get_output(l_conv_final, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Max pooling with pool size 5 downsamples the sequence dimension from 400 to 400/5 = 80
final_max_pool = lasagne.layers.MaxPool1DLayer(l_conv_final, 5)
print('Final max pool layer: {}'.format(
    lasagne.layers.get_output(final_max_pool, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Dense layer with ReLU activation function
l_dense = lasagne.layers.DenseLayer(final_max_pool, num_units=n_hid, name="Dense",
                                    nonlinearity=lasagne.nonlinearities.rectify)
print('Dense layer: {}'.format(
    lasagne.layers.get_output(l_dense, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Output layer with a softmax activation function; dropout is applied to the dense layer's output
l_out = lasagne.layers.DenseLayer(lasagne.layers.dropout(l_dense, p=drop_prob), num_units=n_class, name="Softmax", 
                                  nonlinearity=lasagne.nonlinearities.softmax)
print('Output layer: {}'.format(
    lasagne.layers.get_output(l_out, inputs={l_in: input_var}).eval({input_var: X}).shape))


Input layer: (128, 400, 20)
DimshuffleLayer layer: (128, 20, 400)
Convolutional layer size 3: (128, 10, 400)
Convolutional layer size 5: (128, 10, 400)
Concatenated convolutional layers: (128, 20, 400)
Final convolutional layer: (128, 20, 400)
Final max pool layer: (128, 20, 80)
Dense layer: (128, 30)
Output layer: (128, 10)
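
As a sanity check on the architecture, we can also count the trainable parameters with Lasagne's built-in helper:

# Total number of trainable parameters in the network
n_params = lasagne.layers.count_params(l_out, trainable=True)
print('Trainable parameters: {}'.format(n_params))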

4.- Calculate the prediction and network loss for the training set and update the network weights:


In [5]:
# Get the network output for training; deterministic=False keeps dropout active
prediction = lasagne.layers.get_output(l_out, inputs={l_in: input_var}, deterministic=False)

# Calculate the categorical cross-entropy between the labels and the predictions
t_loss = lasagne.objectives.categorical_crossentropy(prediction, target_var)

# Training loss
loss = T.mean(t_loss)

# Parameters
params = lasagne.layers.get_all_params([l_out], trainable=True)

# Get the network gradients and rescale them so that their total norm does not exceed 3
all_grads = lasagne.updates.total_norm_constraint(T.grad(loss, params), max_norm=3)

# Update parameters using the Adam optimizer
updates = lasagne.updates.adam(all_grads, params, learning_rate=lr)
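
For intuition, total_norm_constraint rescales all gradients jointly so that their global L2 norm does not exceed max_norm, which protects training from occasional exploding gradients. A simplified NumPy sketch of the idea (not Lasagne's actual implementation, which operates on Theano expressions):

def clip_by_global_norm(grads, max_norm=3.0, eps=1e-7):
    # Joint L2 norm over all gradient arrays
    norm = np.sqrt(sum((g ** 2).sum() for g in grads))
    # Rescale only when the joint norm exceeds max_norm
    scale = min(1.0, max_norm / (norm + eps))
    return [g * scale for g in grads]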

5.- Calculate the prediction and network loss for the validation set:


In [6]:
# Get the network output for validation; deterministic=True disables dropout
val_prediction = lasagne.layers.get_output(l_out, inputs={l_in: input_var}, deterministic=True)

# Calculate the categorical cross-entropy between the labels and the predictions
t_val_loss = lasagne.objectives.categorical_crossentropy(val_prediction, target_var)

# Validation loss 
val_loss = T.mean(t_val_loss)

6.- Build the Theano functions:


In [7]:
# Build functions
train_fn = theano.function([input_var, target_var], [loss, prediction], updates=updates)
val_fn = theano.function([input_var, target_var], [val_loss, val_prediction])
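
Before training, a quick smoke test with the dummy batch defined earlier confirms that the compiled functions run; the random labels below are purely illustrative, and we use val_fn so no weights are updated. For an untrained 10-class model the loss should be close to log(10) ≈ 2.30.

y_dummy = np.random.randint(0, n_class, size=batch_size).astype('int32')
smoke_loss, smoke_pred = val_fn(X, y_dummy)
print(smoke_loss, smoke_pred.shape)  # loss near 2.30, predictions of shape (128, 10)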

Load dataset

Once the network is built, the next step is to load the training and validation sets.


In [8]:
# Load the encoded protein sequences, labels and masks
# The masks are not needed for the FFN or CNN models
train = np.load('data/reduced_train.npz')
X_train = train['X_train']
y_train = train['y_train']
mask_train = train['mask_train']
print(X_train.shape)


(2423, 400, 20)

In [9]:
validation = np.load('data/reduced_val.npz')
X_val = validation['X_val']
y_val = validation['y_val']
mask_val = validation['mask_val']
print(X_val.shape)


(635, 400, 20)
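
It can also be useful to inspect the class distribution before training, since accuracy is harder to interpret on imbalanced data; a quick check with NumPy:

# Number of training sequences per localization class
classes, counts = np.unique(y_train, return_counts=True)
for c, n in zip(classes, counts):
    print('Class {}: {} sequences'.format(int(c), n))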

Training

Once the data is ready and the network is compiled, we can start training the model. First we define the number of epochs, i.e. the number of full passes over the training set.


In [10]:
# Number of epochs
num_epochs = 80

# Lists to save loss and accuracy of each epoch
loss_training = []
loss_validation = []
acc_training = []
acc_validation = []
start_time = time.time()
min_val_loss = float("inf")

# Start training 
for epoch in range(num_epochs):
    
    # Full pass training set
    train_err = 0
    train_batches = 0
    confusion_train = ConfusionMatrix(n_class)

    # Generate minibatches and train on each one of them
    for batch in iterate_minibatches(X_train.astype(np.float32), y_train.astype(np.int32), 
                                     mask_train.astype(np.float32), batch_size, shuffle=True, sort_len=False):
        # Inputs to the network
        inputs, targets, in_masks = batch
        # Calculate loss and prediction
        tr_err, predict = train_fn(inputs, targets)
        train_err += tr_err
        train_batches += 1
# Get the predicted class, the one with the highest predicted probability
        preds = np.argmax(predict, axis=-1)
        confusion_train.batch_add(targets, preds)
    
    # Average loss and accuracy
    train_loss = train_err / train_batches
    train_accuracy = confusion_train.accuracy()
    cf_train = confusion_train.ret_mat()

    val_err = 0
    val_batches = 0
    confusion_valid = ConfusionMatrix(n_class)

    # Generate minibatches and validate on each one of them, same procedure as before
    for batch in iterate_minibatches(X_val.astype(np.float32), y_val.astype(np.int32), 
                                     mask_val.astype(np.float32), batch_size, shuffle=True, sort_len=False):
        inputs, targets, in_masks = batch
        err, predict_val = val_fn(inputs, targets)
        val_err += err
        val_batches += 1
        preds = np.argmax(predict_val, axis=-1)
        confusion_valid.batch_add(targets, preds)

    val_loss = val_err / val_batches
    val_accuracy = confusion_valid.accuracy()
    cf_val = confusion_valid.ret_mat()
    
    loss_training.append(train_loss)
    loss_validation.append(val_loss)
    acc_training.append(train_accuracy)
    acc_validation.append(val_accuracy)
    
    # Save the model parameters at the epoch with the lowest validation loss
    if min_val_loss > val_loss:
        min_val_loss = val_loss
        np.savez('params/CNN_params.npz', *lasagne.layers.get_all_param_values(l_out))
    
    print("Epoch {} of {} time elapsed {:.3f}s".format(epoch + 1, num_epochs, time.time() - start_time))
    print("  training loss:\t\t{:.6f}".format(train_loss))
    print("  validation loss:\t\t{:.6f}".format(val_loss))
    print("  training accuracy:\t\t{:.2f} %".format(train_accuracy * 100))
    print("  validation accuracy:\t\t{:.2f} %".format(val_accuracy * 100))


Epoch 1 of 80 time elapsed 2.737s
  training loss:		2.143875
  validation loss:		1.974160
  training accuracy:		21.71 %
  validation accuracy:		29.84 %
Epoch 2 of 80 time elapsed 5.094s
  training loss:		1.994283
  validation loss:		1.888495
  training accuracy:		30.10 %
  validation accuracy:		34.84 %
Epoch 3 of 80 time elapsed 7.632s
  training loss:		1.888262
  validation loss:		1.750514
  training accuracy:		34.42 %
  validation accuracy:		40.16 %
Epoch 4 of 80 time elapsed 10.360s
  training loss:		1.717869
  validation loss:		1.563852
  training accuracy:		41.41 %
  validation accuracy:		49.22 %
Epoch 5 of 80 time elapsed 12.122s
  training loss:		1.550515
  validation loss:		1.369192
  training accuracy:		48.93 %
  validation accuracy:		58.75 %
Epoch 6 of 80 time elapsed 13.901s
  training loss:		1.363479
  validation loss:		1.222367
  training accuracy:		54.11 %
  validation accuracy:		66.72 %
Epoch 7 of 80 time elapsed 15.678s
  training loss:		1.286129
  validation loss:		1.142994
  training accuracy:		55.92 %
  validation accuracy:		67.50 %
Epoch 8 of 80 time elapsed 17.670s
  training loss:		1.186357
  validation loss:		1.083477
  training accuracy:		60.61 %
  validation accuracy:		67.34 %
Epoch 9 of 80 time elapsed 19.440s
  training loss:		1.118458
  validation loss:		0.977686
  training accuracy:		61.31 %
  validation accuracy:		70.16 %
Epoch 10 of 80 time elapsed 21.263s
  training loss:		1.075366
  validation loss:		0.967645
  training accuracy:		62.62 %
  validation accuracy:		71.56 %
Epoch 11 of 80 time elapsed 23.257s
  training loss:		1.013794
  validation loss:		0.893697
  training accuracy:		64.51 %
  validation accuracy:		72.50 %
Epoch 12 of 80 time elapsed 25.644s
  training loss:		0.990296
  validation loss:		0.850944
  training accuracy:		65.17 %
  validation accuracy:		73.59 %
Epoch 13 of 80 time elapsed 27.602s
  training loss:		0.927670
  validation loss:		0.841530
  training accuracy:		66.74 %
  validation accuracy:		72.66 %
Epoch 14 of 80 time elapsed 29.476s
  training loss:		0.907421
  validation loss:		0.807145
  training accuracy:		66.82 %
  validation accuracy:		74.38 %
Epoch 15 of 80 time elapsed 31.692s
  training loss:		0.841751
  validation loss:		0.787659
  training accuracy:		69.98 %
  validation accuracy:		74.22 %
Epoch 16 of 80 time elapsed 33.914s
  training loss:		0.837218
  validation loss:		0.792206
  training accuracy:		68.54 %
  validation accuracy:		74.38 %
Epoch 17 of 80 time elapsed 36.001s
  training loss:		0.796217
  validation loss:		0.752985
  training accuracy:		70.89 %
  validation accuracy:		75.16 %
Epoch 18 of 80 time elapsed 37.844s
  training loss:		0.792815
  validation loss:		0.731923
  training accuracy:		70.68 %
  validation accuracy:		74.69 %
Epoch 19 of 80 time elapsed 39.904s
  training loss:		0.758300
  validation loss:		0.729886
  training accuracy:		70.68 %
  validation accuracy:		76.09 %
Epoch 20 of 80 time elapsed 41.687s
  training loss:		0.767264
  validation loss:		0.727392
  training accuracy:		70.81 %
  validation accuracy:		75.62 %
Epoch 21 of 80 time elapsed 43.374s
  training loss:		0.735753
  validation loss:		0.736207
  training accuracy:		70.60 %
  validation accuracy:		76.09 %
Epoch 22 of 80 time elapsed 45.064s
  training loss:		0.722486
  validation loss:		0.729975
  training accuracy:		72.08 %
  validation accuracy:		77.03 %
Epoch 23 of 80 time elapsed 47.229s
  training loss:		0.684312
  validation loss:		0.724368
  training accuracy:		72.78 %
  validation accuracy:		77.19 %
Epoch 24 of 80 time elapsed 49.269s
  training loss:		0.654248
  validation loss:		0.726558
  training accuracy:		74.01 %
  validation accuracy:		75.78 %
Epoch 25 of 80 time elapsed 50.969s
  training loss:		0.645080
  validation loss:		0.729454
  training accuracy:		74.14 %
  validation accuracy:		76.56 %
Epoch 26 of 80 time elapsed 52.661s
  training loss:		0.636024
  validation loss:		0.748639
  training accuracy:		74.05 %
  validation accuracy:		77.66 %
Epoch 27 of 80 time elapsed 54.362s
  training loss:		0.628844
  validation loss:		0.725967
  training accuracy:		74.34 %
  validation accuracy:		78.44 %
Epoch 28 of 80 time elapsed 56.309s
  training loss:		0.606186
  validation loss:		0.711861
  training accuracy:		75.53 %
  validation accuracy:		78.44 %
Epoch 29 of 80 time elapsed 58.282s
  training loss:		0.584276
  validation loss:		0.763566
  training accuracy:		76.11 %
  validation accuracy:		78.12 %
Epoch 30 of 80 time elapsed 60.058s
  training loss:		0.574085
  validation loss:		0.732832
  training accuracy:		76.11 %
  validation accuracy:		78.59 %
Epoch 31 of 80 time elapsed 61.864s
  training loss:		0.580036
  validation loss:		0.747884
  training accuracy:		76.19 %
  validation accuracy:		77.50 %
Epoch 32 of 80 time elapsed 63.767s
  training loss:		0.558126
  validation loss:		0.767443
  training accuracy:		76.52 %
  validation accuracy:		78.91 %
Epoch 33 of 80 time elapsed 65.740s
  training loss:		0.571885
  validation loss:		0.748526
  training accuracy:		76.48 %
  validation accuracy:		78.12 %
Epoch 34 of 80 time elapsed 67.446s
  training loss:		0.527101
  validation loss:		0.743367
  training accuracy:		77.71 %
  validation accuracy:		79.84 %
Epoch 35 of 80 time elapsed 69.153s
  training loss:		0.518533
  validation loss:		0.748425
  training accuracy:		78.87 %
  validation accuracy:		81.72 %
Epoch 36 of 80 time elapsed 70.882s
  training loss:		0.521540
  validation loss:		0.809446
  training accuracy:		78.58 %
  validation accuracy:		79.38 %
Epoch 37 of 80 time elapsed 72.601s
  training loss:		0.530301
  validation loss:		0.818905
  training accuracy:		78.37 %
  validation accuracy:		80.94 %
Epoch 38 of 80 time elapsed 74.331s
  training loss:		0.492963
  validation loss:		0.767090
  training accuracy:		79.69 %
  validation accuracy:		80.47 %
Epoch 39 of 80 time elapsed 76.074s
  training loss:		0.461024
  validation loss:		0.829002
  training accuracy:		80.14 %
  validation accuracy:		80.78 %
Epoch 40 of 80 time elapsed 77.811s
  training loss:		0.491934
  validation loss:		0.783026
  training accuracy:		79.56 %
  validation accuracy:		82.03 %
Epoch 41 of 80 time elapsed 79.544s
  training loss:		0.454656
  validation loss:		0.803371
  training accuracy:		80.55 %
  validation accuracy:		80.78 %
Epoch 42 of 80 time elapsed 81.289s
  training loss:		0.430194
  validation loss:		0.836164
  training accuracy:		82.48 %
  validation accuracy:		80.62 %
Epoch 43 of 80 time elapsed 83.024s
  training loss:		0.440897
  validation loss:		0.905834
  training accuracy:		81.37 %
  validation accuracy:		79.53 %
Epoch 44 of 80 time elapsed 84.776s
  training loss:		0.412019
  validation loss:		0.853931
  training accuracy:		82.57 %
  validation accuracy:		81.09 %
Epoch 45 of 80 time elapsed 86.495s
  training loss:		0.434522
  validation loss:		0.870272
  training accuracy:		81.46 %
  validation accuracy:		81.88 %
Epoch 46 of 80 time elapsed 88.213s
  training loss:		0.440023
  validation loss:		0.850589
  training accuracy:		81.09 %
  validation accuracy:		80.00 %
Epoch 47 of 80 time elapsed 89.935s
  training loss:		0.389817
  validation loss:		0.841705
  training accuracy:		83.72 %
  validation accuracy:		80.78 %
Epoch 48 of 80 time elapsed 91.653s
  training loss:		0.383544
  validation loss:		0.941057
  training accuracy:		83.76 %
  validation accuracy:		81.25 %
Epoch 49 of 80 time elapsed 93.371s
  training loss:		0.404868
  validation loss:		0.881543
  training accuracy:		83.02 %
  validation accuracy:		81.72 %
Epoch 50 of 80 time elapsed 95.094s
  training loss:		0.375055
  validation loss:		0.905357
  training accuracy:		84.62 %
  validation accuracy:		81.56 %
Epoch 51 of 80 time elapsed 96.806s
  training loss:		0.381056
  validation loss:		0.912396
  training accuracy:		83.88 %
  validation accuracy:		82.50 %
Epoch 52 of 80 time elapsed 98.597s
  training loss:		0.395058
  validation loss:		0.873037
  training accuracy:		83.35 %
  validation accuracy:		82.03 %
Epoch 53 of 80 time elapsed 100.603s
  training loss:		0.407845
  validation loss:		0.856693
  training accuracy:		81.70 %
  validation accuracy:		82.03 %
Epoch 54 of 80 time elapsed 102.566s
  training loss:		0.371751
  validation loss:		0.934208
  training accuracy:		84.13 %
  validation accuracy:		82.19 %
Epoch 55 of 80 time elapsed 104.471s
  training loss:		0.367710
  validation loss:		0.959202
  training accuracy:		84.09 %
  validation accuracy:		82.66 %
Epoch 56 of 80 time elapsed 106.477s
  training loss:		0.367305
  validation loss:		0.903137
  training accuracy:		83.84 %
  validation accuracy:		80.94 %
Epoch 57 of 80 time elapsed 108.383s
  training loss:		0.354351
  validation loss:		1.004953
  training accuracy:		83.92 %
  validation accuracy:		80.78 %
Epoch 58 of 80 time elapsed 110.312s
  training loss:		0.351139
  validation loss:		0.892682
  training accuracy:		85.32 %
  validation accuracy:		81.88 %
Epoch 59 of 80 time elapsed 112.252s
  training loss:		0.372980
  validation loss:		0.958549
  training accuracy:		84.17 %
  validation accuracy:		82.66 %
Epoch 60 of 80 time elapsed 114.248s
  training loss:		0.348411
  validation loss:		0.903453
  training accuracy:		84.91 %
  validation accuracy:		81.72 %
Epoch 61 of 80 time elapsed 116.155s
  training loss:		0.361235
  validation loss:		0.950520
  training accuracy:		84.21 %
  validation accuracy:		81.88 %
Epoch 62 of 80 time elapsed 118.271s
  training loss:		0.339186
  validation loss:		1.004408
  training accuracy:		85.32 %
  validation accuracy:		81.56 %
Epoch 63 of 80 time elapsed 120.169s
  training loss:		0.333286
  validation loss:		1.009766
  training accuracy:		86.43 %
  validation accuracy:		81.25 %
Epoch 64 of 80 time elapsed 122.129s
  training loss:		0.332604
  validation loss:		0.979457
  training accuracy:		85.98 %
  validation accuracy:		82.97 %
Epoch 65 of 80 time elapsed 124.111s
  training loss:		0.337465
  validation loss:		1.063238
  training accuracy:		85.73 %
  validation accuracy:		83.12 %
Epoch 66 of 80 time elapsed 125.986s
  training loss:		0.329280
  validation loss:		0.998134
  training accuracy:		85.90 %
  validation accuracy:		80.78 %
Epoch 67 of 80 time elapsed 127.757s
  training loss:		0.347367
  validation loss:		1.004568
  training accuracy:		84.66 %
  validation accuracy:		82.19 %
Epoch 68 of 80 time elapsed 129.506s
  training loss:		0.331375
  validation loss:		1.088719
  training accuracy:		86.27 %
  validation accuracy:		82.19 %
Epoch 69 of 80 time elapsed 131.327s
  training loss:		0.328786
  validation loss:		1.061946
  training accuracy:		85.36 %
  validation accuracy:		81.56 %
Epoch 70 of 80 time elapsed 133.326s
  training loss:		0.319111
  validation loss:		1.057594
  training accuracy:		86.02 %
  validation accuracy:		82.50 %
Epoch 71 of 80 time elapsed 135.179s
  training loss:		0.334524
  validation loss:		1.054009
  training accuracy:		85.07 %
  validation accuracy:		81.88 %
Epoch 72 of 80 time elapsed 137.218s
  training loss:		0.317473
  validation loss:		1.099902
  training accuracy:		86.39 %
  validation accuracy:		82.19 %
Epoch 73 of 80 time elapsed 139.236s
  training loss:		0.318518
  validation loss:		1.031948
  training accuracy:		86.31 %
  validation accuracy:		82.03 %
Epoch 74 of 80 time elapsed 141.073s
  training loss:		0.327583
  validation loss:		1.115477
  training accuracy:		85.86 %
  validation accuracy:		81.56 %
Epoch 75 of 80 time elapsed 142.844s
  training loss:		0.329498
  validation loss:		1.127070
  training accuracy:		84.91 %
  validation accuracy:		82.34 %
Epoch 76 of 80 time elapsed 144.718s
  training loss:		0.347947
  validation loss:		1.162104
  training accuracy:		84.42 %
  validation accuracy:		79.53 %
Epoch 77 of 80 time elapsed 146.542s
  training loss:		0.317609
  validation loss:		1.135417
  training accuracy:		86.43 %
  validation accuracy:		81.72 %
Epoch 78 of 80 time elapsed 148.335s
  training loss:		0.315220
  validation loss:		1.090182
  training accuracy:		86.14 %
  validation accuracy:		82.50 %
Epoch 79 of 80 time elapsed 150.335s
  training loss:		0.294261
  validation loss:		1.139297
  training accuracy:		86.68 %
  validation accuracy:		82.03 %
Epoch 80 of 80 time elapsed 152.169s
  training loss:		0.312373
  validation loss:		1.155346
  training accuracy:		86.14 %
  validation accuracy:		82.34 %
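
Note that the validation loss bottoms out around epoch 28 and then climbs while the training loss keeps falling, a clear sign of overfitting; saving the parameters at the minimum validation loss, as done above, mitigates this. An early-stopping check can also be run over the recorded losses; a sketch, with patience as an illustrative hyperparameter:

# Find where training would have stopped with early stopping
patience = 10  # illustrative: epochs to wait without improvement
best_loss = float('inf')
epochs_no_improve = 0
for epoch, current_loss in enumerate(loss_validation):
    if current_loss < best_loss:
        best_loss, epochs_no_improve = current_loss, 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print('Early stopping would trigger at epoch {}'.format(epoch + 1))
            break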

In [11]:
print("Minimum validation loss: {:.6f}".format(min_val_loss))


Minimum validation loss: 0.711861
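
Since the last epoch is not the best one, we can load the parameters saved at the minimum validation loss back into the network before any further evaluation. A minimal sketch; the arr_0, arr_1, ... keys come from the np.savez call in the training loop:

# Restore the best parameters into the network
with np.load('params/CNN_params.npz') as f:
    best_params = [f['arr_{}'.format(i)] for i in range(len(f.files))]
lasagne.layers.set_all_param_values(l_out, best_params)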

Model loss and accuracy

Here we plot the loss and the accuracy for the training and validation sets at each epoch.


In [12]:
x_axis = range(num_epochs)
plt.figure(figsize=(8,6))
plt.plot(x_axis, loss_training)
plt.plot(x_axis, loss_validation)
plt.xlabel('Epoch')
plt.ylabel('Error')
plt.legend(('Training','Validation'));



In [13]:
plt.figure(figsize=(8,6))
plt.plot(x_axis, acc_training)
plt.plot(x_axis, acc_validation)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(('Training','Validation'));


Confusion matrix

The confusion matrix allows us to visualize how well each class is predicted and which misclassifications are most common.


In [14]:
# Plot confusion matrix 
# Code based on http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

plt.figure(figsize=(8,8))
cmap=plt.cm.Blues   
plt.imshow(cf_val, interpolation='nearest', cmap=cmap)
plt.title('Confusion matrix validation set')
plt.colorbar()
tick_marks = np.arange(n_class)
classes = ['Nucleus','Cytoplasm','Extracellular','Mitochondrion','Cell membrane','ER',
           'Chloroplast','Golgi apparatus','Lysosome','Vacuole']

plt.xticks(tick_marks, classes, rotation=60)
plt.yticks(tick_marks, classes)

thresh = cf_val.max() / 2.
for i, j in itertools.product(range(cf_val.shape[0]), range(cf_val.shape[1])):
    plt.text(j, i, cf_val[i, j],
             horizontalalignment="center",
             color="white" if cf_val[i, j] > thresh else "black")

plt.tight_layout()
plt.ylabel('True location')
plt.xlabel('Predicted location');
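
Because the classes are imbalanced, a row-normalized version of the same matrix (each row summing to 1, so the diagonal shows per-class recall) can be easier to read; a small variation on the plot above:

# Row-normalized confusion matrix: diagonal entries are per-class recall
# (assumes every class occurs at least once in the validation set)
cf_norm = cf_val.astype('float32') / cf_val.sum(axis=1, keepdims=True)
plt.figure(figsize=(8,8))
plt.imshow(cf_norm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Row-normalized confusion matrix, validation set')
plt.colorbar()
plt.xticks(tick_marks, classes, rotation=60)
plt.yticks(tick_marks, classes)
plt.tight_layout()
plt.ylabel('True location')
plt.xlabel('Predicted location');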


